from pathlib import Path

import torch.distributed as dist
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Normalize, \
    RandomHorizontalFlip, Resize, ToTensor

from torch.utils.data import DataLoader
from data.FFPrivateDataset import PrivateDataset
from data.LMDB import LMDB


# Known private dataset names. Not referenced by the code visible in this
# module; presumably consumed together with ``PrivateDataset`` elsewhere --
# TODO confirm against callers.
PrivateDatasetNames = [
    "FFHQ",
    "sunglasses",
    "babiesv2",
    "Raphael",
    "sunglasses_10",
    "church",
    "haunted",
    "landscapes_art",
    "sketches",
    "Amedeo_Modigliani",
]


def _build_split(config, dataset_name, num_replicas, rank):
    """Build one dataset together with its distributed, shuffling data loader.

    Args:
        config: Experiment configuration. This helper reads ``local_rank``,
            ``model.batch_size``, ``workers`` and ``pin_mem``; the object is
            also forwarded to ``build_dataset``.
        dataset_name: Name of the dataset handed to ``build_dataset``.
        num_replicas: Number of processes taking part in distributed training.
        rank: Global rank of the current process.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    dataset = build_dataset(config, dataset_name)
    print(f"local rank {config.local_rank} / global rank {rank} "
          f"successfully build {dataset_name} dataset")

    sampler = DistributedSampler(dataset,
                                 num_replicas=num_replicas,
                                 rank=rank,
                                 shuffle=True)

    loader = DataLoader(
        dataset=dataset,
        sampler=sampler,
        batch_size=config.model.batch_size,
        num_workers=config.workers,
        pin_memory=config.pin_mem,
        drop_last=False,
        # Shuffling is delegated to the sampler; DataLoader rejects
        # ``shuffle=True`` when an explicit sampler is supplied.
        shuffle=False,
    )
    return dataset, loader


def build_loader(config):
    """Create the training datasets and their distributed data loaders.

    The target dataset (``config.data.target_dataset_name``) is always built
    and stored under the key ``'train'``. When ``config.model.classifier.train``
    is truthy, the source dataset (``config.data.source_dataset_name``) is
    built first and stored under ``'train_source'``.

    The world size and rank are queried from ``torch.distributed``, so the
    default process group is expected to be initialised before this is called.

    NOTE(review): the samplers shuffle, and per the PyTorch documentation
    ``DistributedSampler.set_epoch`` should be called at the start of every
    epoch to get a different ordering; the samplers are reachable through
    ``loader.sampler`` -- confirm the training loop does this.

    Args:
        config: Experiment configuration object (attribute access).

    Returns:
        A ``(dsets, dset_loaders)`` tuple of dictionaries sharing the same
        keys: the first maps to dataset objects, the second to ``DataLoader``
        instances.
    """
    dsets = dict()
    dset_loaders = dict()

    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()

    # The source dataset is only needed when the classifier is being trained.
    if config.model.classifier.train:
        dsets['train_source'], dset_loaders['train_source'] = _build_split(
            config, config.data.source_dataset_name, num_tasks, global_rank)

    dsets['train'], dset_loaders['train'] = _build_split(
        config, config.data.target_dataset_name, num_tasks, global_rank)

    return dsets, dset_loaders


def build_dataset(config, dataset_name):
    """Open the LMDB-backed dataset called ``dataset_name``.

    The database is looked up at the relative path
    ``datasets/<dataset_name>/<dataset_name>.lmdb`` and is paired with the
    transform pipeline produced by ``build_transform(config)``.

    Args:
        config: Experiment configuration, forwarded to ``build_transform``.
        dataset_name: Name of the dataset directory and LMDB file stem.

    Returns:
        The constructed ``LMDB`` dataset instance.
    """
    pipeline = build_transform(config)
    lmdb_path = Path("datasets") / dataset_name / f"{dataset_name}.lmdb"
    return LMDB(root=str(lmdb_path), transforms=pipeline)


def build_transform(config):
    """Assemble the preprocessing pipeline that turns an image into a tensor.

    Stages, in order: bicubic resize to a square of side
    ``config.data.img_size``, random horizontal flip with probability
    ``config.aug.hflip``, tensor conversion, and per-channel normalisation
    with mean 0.5 and std 0.5.

    Args:
        config: Experiment configuration; reads ``data.img_size`` and
            ``aug.hflip``.

    Returns:
        A torchvision ``Compose`` applying the stages above.
    """
    side = config.data.img_size

    # A resize-to-(img_size / 0.9)-then-center-crop variant was tried here
    # previously; the pipeline now resizes straight to the target size.
    resize = Resize(size=[side, side], interpolation=InterpolationMode.BICUBIC)
    flip = RandomHorizontalFlip(p=config.aug.hflip)
    # Per the torchvision docs, ToTensor yields a float torch tensor in CHW
    # layout, scaling typical 8-bit image inputs to the range [0, 1].
    to_tensor = ToTensor()
    # (x - 0.5) / 0.5 maps values in [0, 1] onto [-1, 1].
    normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    return Compose([resize, flip, to_tensor, normalize])
